AdaFactor: avoid updating group["lr"] attributes #9751
Conversation
This affects Adafactor with relative_step=False and scale_parameter=True. Updating group["lr"] makes the result of ._get_lr() depend on the previous call, i.e., on the scale of other parameters. This isn't supposed to happen.
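To make the coupling concrete, here is a simplified, self-contained sketch of the pattern in question (not the actual transformers implementation; the `eps` and `RMS` fields below just mimic the optimizer's per-parameter state):

```python
# Simplified sketch (NOT the actual transformers code) of why writing the
# result of _get_lr() back into group["lr"] couples parameters: with
# relative_step=False and scale_parameter=True, each call multiplies the
# stored lr by a per-parameter scale, so later calls compound earlier ones.
def _get_lr(group, param_state):
    lr = group["lr"]
    if group["scale_parameter"]:
        lr *= max(group["eps"][1], param_state["RMS"])  # per-parameter RMS scaling
    return lr

group = {"lr": 1e-3, "scale_parameter": True, "eps": (1e-30, 1e-3)}
states = [{"RMS": 10.0}, {"RMS": 0.5}]  # two parameters with very different scales

for state in states:
    group["lr"] = _get_lr(group, state)  # the problematic in-place update
    print(group["lr"])  # 0.01, then 0.005 -- the second lr inherits the first parameter's scale
```

With independent calls, the second parameter's lr would be 0.0005; because the group entry was mutated, it comes out as 0.005.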
Can you provide evidence that supports the following:
Thanks!
Hi, thanks for the quick reply. This is taken from the AdaFactor paper (the step-size definitions are reproduced below): as you can see, ρ only depends on the step number if we use relative steps. And if we switch to any other learning rate schedule (in my case, linear warmup + cosine decay), it doesn't make sense to make the ρ part depend on the scale of the other parameters, nor can I find any reference to this approach in the paper. If we (loosely) factor the α_t in the original implementation into per-parameter terms α_{i,t}, each α_{i,t} should depend only on that parameter's own RMS and the step number t, not on the other parameters.
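For reference, the step-size definitions being referred to (reproduced from the AdaFactor paper, Shazeer & Stern 2018; this is a reconstruction, not the original figure):

```latex
% Relative step size: depends only on the step number t.
\rho_t = \min\!\left(10^{-2},\ \tfrac{1}{\sqrt{t}}\right)
% Absolute step size: scales \rho_t by the RMS of the parameter being updated.
\alpha_t = \max\!\left(\epsilon_2,\ \mathrm{RMS}(X_{t-1})\right)\rho_t
```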
I should probably clarify what I meant by "weird behaviors": the model (T5 v1.1) never converged when trained with Adafactor in this configuration (relative_step=False, scale_parameter=True, plus an external LR scheduler).
cc @patrickvonplaten @patil-suraj
I agree very much with your explanation here @ceshine - that's a great fix, thanks!
BTW, if you have some working code for how to train a google/t5v1_1 model, I think it would be super helpful to post it here, on the forum, or as a community notebook! Many people have been asking for good t5v1_1 training scripts :-)
Looks good. Thank you!
It doesn't look like any other entry in group gets modified.
Ideally, a situation like this is a great opportunity to add a test that detects the problem - i.e., lack of convergence - though I can imagine that would be quite tricky to accomplish!
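A much simpler regression test than a convergence test could just assert that step() leaves group["lr"] untouched. A sketch, assuming the transformers Adafactor API; the test name and hyperparameters are hypothetical:

```python
import torch
from transformers.optimization import Adafactor

def test_step_does_not_mutate_group_lr():
    # With relative_step=False and scale_parameter=True, stepping over
    # parameters of very different scales must not change the lr stored
    # in the param group.
    params = [
        torch.nn.Parameter(torch.randn(4, 4) * 100.0),
        torch.nn.Parameter(torch.randn(4, 4) * 0.01),
    ]
    opt = Adafactor(
        params, lr=1e-3, relative_step=False, scale_parameter=True, warmup_init=False
    )
    for p in params:
        p.grad = torch.ones_like(p)
    opt.step()
    assert opt.param_groups[0]["lr"] == 1e-3
```

Before this patch the assertion would fail because group["lr"] gets overwritten with the scaled value; after it, the stored lr is left alone.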
Thank you for your explanation and for providing references. LGTM.
Thank you all for your time and for accepting the patch! Glad to have made a tiny contribution to this great library.
I don't have anything that is sufficiently readable yet. Nonetheless, I have these notebooks published on Kaggle that use the patched Adafactor: one for T5 v1.1 and one for mT5. They are based on this GitHub repo, which is quite messy at the moment. The part that sets up the optimizer is located here.
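For readers who just want the gist of that setup, here is a minimal sketch of pairing Adafactor (relative_step=False, scale_parameter=True) with a linear-warmup + cosine-decay schedule. The model and hyperparameters are placeholders, not the notebooks' actual values:

```python
import torch
from transformers.optimization import Adafactor, get_cosine_schedule_with_warmup

# Stand-in model; the notebooks train T5 v1.1 / mT5 instead.
model = torch.nn.Linear(8, 8)

optimizer = Adafactor(
    model.parameters(),
    lr=1e-3,               # base lr, to be shaped by the external scheduler
    relative_step=False,   # disable Adafactor's internal schedule
    scale_parameter=True,  # keep per-parameter RMS scaling
    warmup_init=False,
)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=1_000, num_training_steps=100_000
)
```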
What does this PR do?
I've observed weird behaviors when using Adafactor with relative_step=False and scale_parameter=True together with an LR scheduler. I think the problem is that the code updates the lr attribute of the current parameter group and then uses the updated attribute to calculate the learning rate for the next parameter. I don't think this is supposed to happen. A simple fix would be replacing the update operation with an assignment to a local variable.
I'm not entirely sure if I understand the problem correctly, so I apologize in advance if this is a stupid PR. I'd appreciate it if someone could point out where I am wrong. Thanks!
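Schematically, the fix keeps the computed learning rate in a local variable instead of writing it back into the group (same simplified sketch as above, not the actual transformers code):

```python
# With the fix, group["lr"] is never mutated, so each parameter's lr depends
# only on the base lr and that parameter's own RMS.
group = {"lr": 1e-3, "scale_parameter": True, "eps": (1e-30, 1e-3)}
states = [{"RMS": 10.0}, {"RMS": 0.5}]

for state in states:
    lr = group["lr"]
    if group["scale_parameter"]:
        lr *= max(group["eps"][1], state["RMS"])
    print(lr)  # 0.01, then 0.0005; group["lr"] stays at 1e-3
```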
Before submitting
- Did you read the contributor guidelines, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@moscow25 @sshleifer